Amazon SageMakerのLinearLearnerが多値分類に対応したのでやってみた
こんにちは、小澤です。
Amazon SageMaker(以下SageMaker)のLinearLearnerが多値分類に対応していたので、やってみました。
学習させてみる
LinearLearnerで多値分類をするための設定は以下のようになります。 ここでは、分類するクラス数をnとしています。
- 正解ラベルとなる値を0, 1,..., n-1の数値にする
- predictor_typeをmulticlass_classifierに設定する
- num_classesでnを設定する
今回は、SageMakerのサンプルにある Introduction to Amazon algorithms > linear_learner_mnist のをベースに二値分類から多値分類に変更してみます。
このサンプルではData conversionでmnistの正解ラベルが0の時は0, それ以外の時は1という変換を行っています。 これは書かれている数字が0かそれ以外かという二値分類を行うためです。
labels = np.where(np.array([t.tolist() for t in train_set[1]]) == 0, 1, 0).astype('float32')
これを多値分類を使って書かれている数字をそのまま予測するように正解ラベルの値を設定してみます。
labels = np.array(np.array([t.tolist() for t in train_set[1]])).astype('float32')
mnistのデータは0~9の10種類の数字を予測するものになっています。 LinearLearnerの多値分類では、n種類の分類の場合0からn-1までのn種類の値を各ラベル値として設定するので、書かれている数字をそのまま正解ラベルにしています。
続いて、Training the linear modelでハイパーパラメータを以下のように設定します。
linear.set_hyperparameters(feature_dim=784, num_classes=10, predictor_type='multiclass_classifier', mini_batch_size=200)
feature_dimとmini_batch_sizeの値はそのまま、predictor_typeをmulticlass_classifierに変更してnum_classesを新たに追加しています。 今回は10種類の分類なので、num_classesは10に設定しています。
これで学習させれば、多値分類での学習が行えます。 マネジメントコンソールのトレーニングジョブから設定てしたハイパーパラメータの設定が確認できます。
結果を確認してみる
では、続いて予測を行ってみます。 こちらも同様のサンプルをベースに確認していきます。
予測を行うに祭してのコードの変更は必要ありません。 サンプルのものをそのまま実行します。
result = linear_predictor.predict(train_set[0][30:31]) print(result)
結果は以下のようになります。
{'predictions': [ {'score': [ 1.000000013351432e-10, 7.719642599113286e-05, 0.00023102026898413897, 0.9992199540138245, 2.4312312234542333e-06, 3.6111225199420005e-05, 3.700147033214307e-07, 0.00016467596287839115, 0.0002673699054867029, 8.82955305314681e-07], 'predicted_label': 3.0 }]}
scoreに各ラベルごとの確率、predicted_labelにラベルの予測値が入っています。
複数のデータに対して同時に実行することも可能なので、以下のようにするといくつかのデータで予測結果の確認ができます。
for i, j in zip(train_set[1][30:40], linear_predictor.predict(train_set[0][30:40])['predictions']): print('actuals : {}, predictions : {}'.format(i, j['predicted_label']))
結果は以下のようになります。
actuals : 3, predictions : 3.0 actuals : 8, predictions : 8.0 actuals : 6, predictions : 6.0 actuals : 9, predictions : 9.0 actuals : 0, predictions : 0.0 actuals : 5, predictions : 5.0 actuals : 6, predictions : 6.0 actuals : 0, predictions : 0.0 actuals : 7, predictions : 7.0 actuals : 6, predictions : 6.0
最後にサンプルで混同行列を作成している部分も多値分類に合わせて書き換えてみましょう。 サンプルにてtest_setデータに対して予測を行う部分処理の部分に変更はありません。 混同行列の作成を行う以下の部分を変更します。
pd.crosstab(np.where(test_set[1] == 0, 1, 0), predictions, rownames=['actuals'], colnames=['predictions'])
変更内容は以下のように、0と1への変換を行う部分を除くのみになります。
pd.crosstab(test_set[1], predictions, rownames=['actuals'], colnames=['predictions'])
結果は以下のようになります。
おわりに
SageMakerのLinearLearnerを使った多値分類をやってみました。 これまで、画像認識以外でSageMakerのBuild-Inアルゴリズムを使った多値分類を行うにはXGBoostのみでした。 XGBoostは非常にいい結果を得られるアルゴリズムですが、線形分類器を使った多値分類はその性質上、需要も多くあるのでこれはなかなか嬉しいアップデートです。